﻿#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System.IO.Pipelines;
using System.Text;
using Grpc.AspNetCore.Server.Tests.Infrastructure;
using Grpc.AspNetCore.Web;
using Grpc.AspNetCore.Web.Internal;
using Grpc.Tests.Shared;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Options;
using NUnit.Framework;

namespace Grpc.AspNetCore.Server.Tests.Web;

/// <summary>
/// Tests for <see cref="GrpcWebMiddleware"/>: how requests are classified by
/// <c>GetGrpcWebContext</c> and how <c>Invoke</c> rewrites gRPC-Web requests and responses.
/// </summary>
[TestFixture]
public class GrpcWebMiddlewareTests
{
    // gRPC-Web frame header: one flags byte followed by a 4-byte big-endian payload length.
    private const int FrameHeaderLength = 5;

    // Flags byte value that marks a gRPC-Web frame as a trailers frame.
    private const byte TrailersFrameFlag = 0x80;

    [Test]
    public async Task Invoke_NonGrpcWebContentType_NotProcessed()
    {
        // Arrange
        var testSink = new TestSink();
        var testLoggerFactory = new TestLoggerFactory(testSink, true);

        var middleware = CreateMiddleware(logger: testLoggerFactory.CreateLogger<GrpcWebMiddleware>());
        var httpContext = new DefaultHttpContext();

        // Act
        await middleware.Invoke(httpContext);

        // Assert
        Assert.AreEqual(0, testSink.Writes.Count);
        Assert.IsNull(httpContext.Features.Get<IHttpResponseTrailersFeature>());
    }

    [TestCase(GrpcWebProtocolConstants.GrpcWebContentType, nameof(ServerGrpcWebMode.GrpcWeb))]
    [TestCase(GrpcWebProtocolConstants.GrpcWebContentType + "+proto", nameof(ServerGrpcWebMode.GrpcWeb))]
    [TestCase(GrpcWebProtocolConstants.GrpcWebTextContentType, nameof(ServerGrpcWebMode.GrpcWebText))]
    [TestCase(GrpcWebProtocolConstants.GrpcWebTextContentType + "+proto", nameof(ServerGrpcWebMode.GrpcWebText))]
    [TestCase(GrpcWebProtocolConstants.GrpcContentType, nameof(ServerGrpcWebMode.None))]
    [TestCase("application/json", nameof(ServerGrpcWebMode.None))]
    [TestCase("", nameof(ServerGrpcWebMode.None))]
    public void GetGrpcWebMode_ContentTypes_Matched(string contentType, string expectedGrpcWebMode)
    {
        // Arrange
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = HttpMethods.Post;
        httpContext.Request.ContentType = contentType;

        // Act
        var grpcWebContext = GrpcWebMiddleware.GetGrpcWebContext(httpContext);

        // Assert
        Assert.AreEqual(Enum.Parse<ServerGrpcWebMode>(expectedGrpcWebMode), grpcWebContext.Request);
    }

    [TestCase(GrpcWebProtocolConstants.GrpcWebContentType, null,
        nameof(ServerGrpcWebMode.GrpcWeb), nameof(ServerGrpcWebMode.GrpcWeb))]
    [TestCase(GrpcWebProtocolConstants.GrpcWebContentType, GrpcWebProtocolConstants.GrpcWebTextContentType,
        nameof(ServerGrpcWebMode.GrpcWeb), nameof(ServerGrpcWebMode.GrpcWebText))]
    [TestCase(GrpcWebProtocolConstants.GrpcWebTextContentType, GrpcWebProtocolConstants.GrpcWebTextContentType,
        nameof(ServerGrpcWebMode.GrpcWebText), nameof(ServerGrpcWebMode.GrpcWebText))]
    [TestCase("application/json", null,
        nameof(ServerGrpcWebMode.None), nameof(ServerGrpcWebMode.None))]
    [TestCase("", null,
        nameof(ServerGrpcWebMode.None), nameof(ServerGrpcWebMode.None))]
    public void GetGrpcWebMode_Accept_Matched(
        string? contentType, string? accept,
        string expectedRequestGrpcWebMode, string expectedResponseGrpcWebMode)
    {
        // Arrange
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = HttpMethods.Post;
        httpContext.Request.ContentType = contentType!;
        httpContext.Request.Headers["Accept"] = accept;

        // Act
        var grpcWebContext = GrpcWebMiddleware.GetGrpcWebContext(httpContext);

        // Assert
        Assert.AreEqual(Enum.Parse<ServerGrpcWebMode>(expectedRequestGrpcWebMode), grpcWebContext.Request);
        Assert.AreEqual(Enum.Parse<ServerGrpcWebMode>(expectedResponseGrpcWebMode), grpcWebContext.Response);
    }

    [Test]
    public void GetGrpcWebMode_NonPost_NotMatched()
    {
        // Arrange
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = HttpMethods.Options;
        httpContext.Request.ContentType = GrpcWebProtocolConstants.GrpcWebContentType;

        // Act
        var grpcWebContext = GrpcWebMiddleware.GetGrpcWebContext(httpContext);

        // Assert
        Assert.AreEqual(ServerGrpcWebMode.None, grpcWebContext.Request);
        Assert.AreEqual(ServerGrpcWebMode.None, grpcWebContext.Response);
    }

    [Test]
    public async Task Invoke_GrpcWebContentTypeAndNotEnabled_NotProcessed()
    {
        // Arrange
        var testSink = new TestSink();
        var testLoggerFactory = new TestLoggerFactory(testSink, true);

        var middleware = CreateMiddleware(logger: testLoggerFactory.CreateLogger<GrpcWebMiddleware>());
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = HttpMethods.Post;
        httpContext.Request.ContentType = GrpcWebProtocolConstants.GrpcWebContentType;

        // Act
        await middleware.Invoke(httpContext);

        // Assert
        Assert.IsNull(httpContext.Features.Get<IHttpResponseTrailersFeature>());

        // A request that is not processed must reach the next delegate unmodified.
        Assert.AreEqual(GrpcWebProtocolConstants.GrpcWebContentType, httpContext.Request.ContentType);

        Assert.AreEqual(2, testSink.Writes.Count);
        var writes = testSink.Writes.ToList();
        Assert.AreEqual("DetectedGrpcWebRequest", writes[0].EventId.Name);
        Assert.AreEqual("GrpcWebRequestNotProcessed", writes[1].EventId.Name);
    }

    [Test]
    public async Task Invoke_GrpcWebContentTypeAndEnabled_Processed()
    {
        // Arrange
        var testSink = new TestSink();
        var testLoggerFactory = new TestLoggerFactory(testSink, true);

        var middleware = CreateMiddleware(
            options: new GrpcWebOptions { DefaultEnabled = true },
            logger: testLoggerFactory.CreateLogger<GrpcWebMiddleware>());
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = HttpMethods.Post;
        httpContext.Request.ContentType = GrpcWebProtocolConstants.GrpcWebContentType;

        // Act
        await middleware.Invoke(httpContext);

        // Assert
        // Processing rewrites the request content type to plain gRPC, the same observable
        // effect checked in Invoke_GrpcWebContentTypeAndMetadata_Processed.
        Assert.AreEqual(GrpcWebProtocolConstants.GrpcContentType, httpContext.Request.ContentType);

        Assert.AreEqual(1, testSink.Writes.Count);
        var writes = testSink.Writes.ToList();
        Assert.AreEqual("DetectedGrpcWebRequest", writes[0].EventId.Name);
    }

    [Test]
    public async Task Invoke_GrpcWebContentTypeAndMetadata_Processed()
    {
        // Arrange
        var middleware = CreateMiddleware(options: new GrpcWebOptions());
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Protocol = "HTTP/1.1";
        httpContext.Request.Method = HttpMethods.Post;
        httpContext.Request.ContentType = GrpcWebProtocolConstants.GrpcWebContentType;
        httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableGrpcWebAttribute()), string.Empty));

        // The test feature records OnStarting callbacks so the test can run them explicitly.
        var testHttpResponseFeature = new TestHttpResponseFeature();
        httpContext.Features.Set<IHttpResponseFeature>(testHttpResponseFeature);

        // Act 1
        await middleware.Invoke(httpContext);

        // Assert 1
        Assert.AreEqual(GrpcWebProtocolConstants.GrpcContentType, httpContext.Request.ContentType);
        Assert.AreEqual(GrpcWebProtocolConstants.Http2Protocol, httpContext.Request.Protocol);
        Assert.AreEqual(1, testHttpResponseFeature.StartingCallbackCount);

        // Act 2: simulate the response starting with a gRPC content type.
        httpContext.Response.ContentType = GrpcWebProtocolConstants.GrpcContentType;

        var c = testHttpResponseFeature.StartingCallbacks[0];
        await c.callback(c.state);

        // Assert 2: the original protocol is restored and the response is converted back to gRPC-Web.
        Assert.AreEqual("HTTP/1.1", httpContext.Request.Protocol);
        Assert.AreEqual(GrpcWebProtocolConstants.GrpcWebContentType, httpContext.Response.ContentType);
    }

    [Test]
    public async Task Invoke_GrpcWebContentTypeAndMetadata_WriteToResponseStream_Processed()
    {
        // Arrange
        var expectedMessage = Encoding.UTF8.GetBytes("Hello world");
        const string expectedTrailers = "one: two\r\n";

        var middleware = CreateMiddleware(
            options: new GrpcWebOptions { DefaultEnabled = true },
            next: c => c.GetEndpoint()!.RequestDelegate!.Invoke(c));
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Protocol = "HTTP/1.1";
        httpContext.Request.Method = HttpMethods.Post;
        httpContext.Request.ContentType = GrpcWebProtocolConstants.GrpcWebContentType;
        httpContext.SetEndpoint(new Endpoint(
            c =>
            {
                c.Response.Body.Write(expectedMessage);
                c.Response.AppendTrailer("one", "two");
                return Task.CompletedTask;
            },
            new EndpointMetadataCollection(),
            string.Empty));

        var testHttpResponseFeature = new TestHttpResponseFeature();
        httpContext.Features.Set<IHttpResponseFeature>(testHttpResponseFeature);

        // Capture everything written to the response body.
        using var ms = new MemoryStream();
        httpContext.Features.Set<IHttpResponseBodyFeature>(new TestResponseBodyFeature(PipeWriter.Create(ms)));

        // Act 1
        await middleware.Invoke(httpContext);

        // Assert 1
        Assert.AreEqual(GrpcWebProtocolConstants.GrpcContentType, httpContext.Request.ContentType);
        Assert.AreEqual(GrpcWebProtocolConstants.Http2Protocol, httpContext.Request.Protocol);
        Assert.AreEqual(1, testHttpResponseFeature.StartingCallbackCount);

        // Act 2
        httpContext.Response.ContentType = GrpcWebProtocolConstants.GrpcContentType;

        var c = testHttpResponseFeature.StartingCallbacks[0];
        await c.callback(c.state);

        // Assert 2
        Assert.AreEqual("HTTP/1.1", httpContext.Request.Protocol);
        Assert.AreEqual(GrpcWebProtocolConstants.GrpcWebContentType, httpContext.Response.ContentType);

        // The body is the message written by the endpoint followed by a trailers frame.
        var bodyContent = ms.ToArray().AsMemory();

        Assert.IsTrue(bodyContent.Slice(0, expectedMessage.Length).Span.SequenceEqual(expectedMessage));
        var trailerContent = bodyContent.Slice(expectedMessage.Length);

        Assert.AreEqual(FrameHeaderLength + expectedTrailers.Length, trailerContent.Length);

        // Frame header: trailers flag, then the payload length as a 4-byte big-endian integer.
        Assert.AreEqual(TrailersFrameFlag, trailerContent.Span[0]);
        Assert.AreEqual(0, trailerContent.Span[1]);
        Assert.AreEqual(0, trailerContent.Span[2]);
        Assert.AreEqual(0, trailerContent.Span[3]);
        Assert.AreEqual(expectedTrailers.Length, trailerContent.Span[4]);

        var text = Encoding.ASCII.GetString(trailerContent.Span.Slice(FrameHeaderLength));

        Assert.AreEqual(expectedTrailers, text);
    }

    /// <summary>
    /// Creates a <see cref="GrpcWebMiddleware"/> for tests.
    /// </summary>
    /// <param name="options">Middleware options; defaults to a new <c>GrpcWebOptions</c> instance.</param>
    /// <param name="logger">Logger; defaults to <c>NullLogger&lt;GrpcWebMiddleware&gt;.Instance</c>.</param>
    /// <param name="next">Next delegate in the pipeline; defaults to a delegate that completes immediately.</param>
    private static GrpcWebMiddleware CreateMiddleware(
        GrpcWebOptions? options = null,
        ILogger<GrpcWebMiddleware>? logger = null,
        RequestDelegate? next = null)
    {
        return new GrpcWebMiddleware(
            Options.Create(options ?? new GrpcWebOptions()),
            logger ?? NullLogger<GrpcWebMiddleware>.Instance,
            next ?? EmptyRequestDelegate);

        static Task EmptyRequestDelegate(HttpContext context) => Task.CompletedTask;
    }
}
